{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ISubpr_SSsiM"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "3jTMb1dySr3V"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6DWfyNThSziV"
      },
      "source": [
        "# 使用 tf.function 提高性能\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://tensorflow.google.cn/guide/function\" class=\"\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\" class=\"\">在 TensorFlow.org 上查看</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/guide/function.ipynb\" class=\"\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\" class=\"\">在 Google Colab 中运行</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/guide/function.ipynb\" class=\"\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\" class=\"\">在 GitHub 上查看源代码</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/guide/function.ipynb\" class=\"\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\" class=\"\">下载笔记本</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J122XQYG7W6w"
      },
      "source": [
        "在 TensorFlow 2 中，默认情况下会打开 Eager Execution 模式。这种模式下的用户界面非常灵活直观（执行一次性运算要简单快速得多），但可能会牺牲一定的性能和可部署性。\n",
        "\n",
        "您可以使用 `tf.function` 将程序转换为计算图。这是一个转换工具，用于从 Python 代码创建独立于 Python 的数据流图。它可以帮助您创建高效且可移植的模型，并且如果要使用 `SavedModel`，则必须使用此工具。\n",
        "\n",
        "本指南介绍 `tf.function` 的底层工作原理，让您形成概念化理解，从而有效地加以利用。\n",
        "\n",
        "要点和建议包括：\n",
        "\n",
        "- 先在 Eager 模式下调试，然后使用 `@tf.function` 进行装饰。\n",
        "- 不依赖 Python 的副作用，如对象变异或列表追加。\n",
        "- `tf.function` 最适合处理 TensorFlow 运算；NumPy 和 Python 调用会转换为常量。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SjvqpgepHJPd"
      },
      "source": [
        "## 设置"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "otIdN1TS8N7S"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I0xDjO4SHLUD"
      },
      "source": [
        "定义一个辅助函数来演示可能遇到的错误类型："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "D25apou9IOXa"
      },
      "outputs": [],
      "source": [
        "import traceback\n",
        "import contextlib\n",
        "\n",
        "# Some helper code to demonstrate the kinds of errors you might encounter.\n",
        "@contextlib.contextmanager\n",
        "def assert_raises(error_class):\n",
        "  try:\n",
        "    yield\n",
        "  except error_class as e:\n",
        "    print('Caught expected exception \\n  {}:'.format(error_class))\n",
        "    traceback.print_exc(limit=2)\n",
        "  except Exception as e:\n",
        "    raise e\n",
        "  else:\n",
        "    raise Exception('Expected {} to be raised but no error was raised!'.format(\n",
        "        error_class))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WPSfepzTHThq"
      },
      "source": [
        "## 基础知识"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CNwYTIJ8r56W"
      },
      "source": [
        "### 用法\n",
        "\n",
        "您定义的 `Function` 就像核心 TensorFlow 运算：您可以在 Eager 模式下执行，可以计算梯度，等等。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SbtT1-Wm70F2"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def add(a, b):\n",
        "  return a + b\n",
        "\n",
        "add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uP-zUelB8DbX"
      },
      "outputs": [],
      "source": [
        "v = tf.Variable(1.0)\n",
        "with tf.GradientTape() as tape:\n",
        "  result = add(v, 1.0)\n",
        "tape.gradient(result, v)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ocWZvqrmHnmX"
      },
      "source": [
        "`Function` 中可以嵌套其他 `Function`。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l5qRjdbBVdU6"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def dense_layer(x, w, b):\n",
        "  return add(tf.matmul(x, w), b)\n",
        "\n",
        "dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "piBhz7gYsHqU"
      },
      "source": [
        "`Function` 的执行速度比 Eager 代码快，尤其是对于包含很多简单运算的计算图。但是，对于包含一些复杂运算（如卷积）的计算图，速度提升不会太明显。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zuXt4wRysI03"
      },
      "outputs": [],
      "source": [
        "import timeit\n",
        "conv_layer = tf.keras.layers.Conv2D(100, 3)\n",
        "\n",
        "@tf.function\n",
        "def conv_fn(image):\n",
        "  return conv_layer(image)\n",
        "\n",
        "image = tf.zeros([1, 200, 200, 100])\n",
        "# warm up\n",
        "conv_layer(image); conv_fn(image)\n",
        "print(\"Eager conv:\", timeit.timeit(lambda: conv_layer(image), number=10))\n",
        "print(\"Function conv:\", timeit.timeit(lambda: conv_fn(image), number=10))\n",
        "print(\"Note how there's not much difference in performance for convolutions\")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uZ4Do2AV80cO"
      },
      "source": [
        "### 跟踪\n",
        "\n",
        "Python 的动态类型意味着您可以调用包含各种参数类型的函数，在各种场景下，Python 的行为可能有所不同。\n",
        "\n",
        "但是，创建 TensorFlow 计算图需要静态 `dtype` 和形状维度。`tf.function` 通过包装一个 Python 函数来创建 `Function` 对象，弥补了这一缺陷。根据提供的输入，`Function` 为其选择相应的计算图，从而在必要时追溯 Python 函数。理解发生跟踪的原因和时机后，有效运用 `tf.function` 就会容易得多！\n",
        "\n",
        "您可以通过调用包含不同类型参数的 `Function` 来切实观察这种多态行为。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kojmJrgq8U9v"
      },
      "outputs": [],
      "source": [
        "# Functions are polymorphic\n",
        "\n",
        "@tf.function\n",
        "def double(a):\n",
        "  print(\"Tracing with\", a)\n",
        "  return a + a\n",
        "\n",
        "print(double(tf.constant(1)))\n",
        "print()\n",
        "print(double(tf.constant(1.1)))\n",
        "print()\n",
        "print(double(tf.constant(\"a\")))\n",
        "print()\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QPfouGUQrcNb"
      },
      "source": [
        "请注意，如果重复调用包含相同参数类型的 `Function`，TensorFlow 会重复使用之前跟踪的计算图，因为后面的调用生成的计算图将相同。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hFccbWFRrsBp"
      },
      "outputs": [],
      "source": [
        "# This doesn't print 'Tracing with ...'\n",
        "print(double(tf.constant(\"b\")))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fgIO_XEzcB9o"
      },
      "source": [
        "（以下更改存在于 TensorFlow Nightly 版本中，并且将在 TensorFlow 2.3 中提供。）\n",
        "\n",
        "您可以使用 `pretty_printed_concrete_signatures()` 查看所有可用跟踪："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IiQc4IKAb-NX"
      },
      "outputs": [],
      "source": [
        "print(double.pretty_printed_concrete_signatures())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rKQ92VEWI7n8"
      },
      "source": [
        "目前，您已经了解 `tf.function` 通过 TensorFlow 的计算图跟踪逻辑创建缓存的动态调度层。对于术语的含义，更具体的解释如下：\n",
        "\n",
        "- `tf.Graph` 与语言无关，是对计算的原始可移植表示。\n",
        "- `ConcreteFunction` 是 `tf.Graph` 的 Eeager 执行包装器。\n",
        "- `Function` 管理 `ConcreteFunction` 的缓存，并为输入选择正确的缓存。\n",
        "- `tf.function` 包装 Python 函数，并返回一个 `Function` 对象。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "96IxS2WR37fF"
      },
      "source": [
        "### 获取具体函数\n",
        "\n",
        "每次跟踪函数时都会创建一个新的具体函数。您可以使用 `get_concrete_function` 直接获取具体函数。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mHg2CGtPQ3Hz"
      },
      "outputs": [],
      "source": [
        "print(\"Obtaining concrete trace\")\n",
        "double_strings = double.get_concrete_function(tf.constant(\"a\"))\n",
        "print(\"Executing traced function\")\n",
        "print(double_strings(tf.constant(\"a\")))\n",
        "print(double_strings(a=tf.constant(\"b\")))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6IVZ-NVf9vsx"
      },
      "outputs": [],
      "source": [
        "# You can also call get_concrete_function on an InputSpec\n",
        "double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))\n",
        "print(double_strings_from_inputspec(tf.constant(\"c\")))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iR4fVmG34xvF"
      },
      "source": [
        "（以下更改存在于 TensorFlow Nightly 版本中，并且将在 TensorFlow 2.3 中提供。）\n",
        "\n",
        "打印 `ConcreteFunction` 会显示其输入参数（及类型）和输出类型的摘要。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "o3-JbkIk41r8"
      },
      "outputs": [],
      "source": [
        "print(double_strings)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QtqfvljZeuOV"
      },
      "source": [
        "您也可以直接检索具体函数的签名。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nzbrqFABe0zG"
      },
      "outputs": [],
      "source": [
        "print(double_strings.structured_input_signature)\n",
        "print(double_strings.structured_outputs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lar5A_5m5IG1"
      },
      "source": [
        "对不兼容的类型使用具体跟踪会引发错误"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "G5eeTK-T5KYj"
      },
      "outputs": [],
      "source": [
        "with assert_raises(tf.errors.InvalidArgumentError):\n",
        "  double_strings(tf.constant(1))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "st2L9VNQVtSG"
      },
      "source": [
        "您可能会注意到，在具体函数的输入签名中对 Python 参数进行了特别处理。TensorFlow 2.3 之前的版本会将 Python 参数直接从具体函数的签名中删除。从 TensorFlow 2.3 开始，Python 参数会保留在签名中，但是会受到约束，只能获取在跟踪期间设置的值。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U_QyPSGoaC35"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def pow(a, b):\n",
        "  return a ** b\n",
        "\n",
        "square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)\n",
        "print(square)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "E76vIDhQbXIb"
      },
      "outputs": [],
      "source": [
        "assert square(tf.constant(10.0)) == 100\n",
        "\n",
        "with assert_raises(TypeError):\n",
        "  square(tf.constant(10.0), b=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "41gJh_JGIfuA"
      },
      "source": [
        "### 获取计算图\n",
        "\n",
        "每个具体函数都是 `tf.Graph` 的可调用包装器。虽然一般不需要检索实际 `tf.Graph` 对象，不过，您可以从任何具体函数轻松获得实际对象。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5UENeGHfaX8g"
      },
      "outputs": [],
      "source": [
        "graph = double_strings.graph\n",
        "for node in graph.as_graph_def().node:\n",
        "  print(f'{node.input} -&gt; {node.name}')\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aIKkgr6qdtp4"
      },
      "source": [
        "### 调试\n",
        "\n",
        "通常，在 Eager 模式下调试代码比在 `tf.function` 中简单。在使用 `tf.function` 进行装饰之前，进行装饰之前，您应该先确保代码可在 Eager 模式下无错误执行。为了帮助调试，您可以调用 `tf.config.run_functions_eagerly(True)` 来全局停用和重新启用 `tf.function`。\n",
        "\n",
        "追溯仅在 `tf.function` 中出现的问题时，可参考下面的几点提示：\n",
        "\n",
        "- 普通旧 Python `print` 调用仅在跟踪期间执行，可用于追溯（重新）跟踪函数的时间。\n",
        "- `tf.print` 调用每次都会执行，可用于追溯执行过程中产生的中间值。\n",
        "- 利用 `tf.debugging.enable_check_numerics` 很容易追溯到 NaN 和 Inf 在何处创建。\n",
        "- `pdb` 可以帮助您理解跟踪的详细过程。（提醒：使用 PDB 调试时，AutoGraph 会自动转换 Python 源代码。）"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "129-iRsPS-gY"
      },
      "source": [
        "## 跟踪语义"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h62XoXho6EWN"
      },
      "source": [
        "### 缓存键规则\n",
        "\n",
        "通过从输入的参数和关键词参数计算缓存键，`Function` 可以确定是否重复使用跟踪的具体函数。\n",
        "\n",
        "- 为 `tf.Tensor` 参数生成的键是其形状和 dtype。\n",
        "- 从 TensorFlow 2.3 开始，为 `tf.Variable` 参数生成的键是其 `id()`。\n",
        "- 为 Python 基元生成的键是其值。为嵌套 `dict`、 `list`、 `tuple`、 `namedtuple` 和 [`attr`](https://www.attrs.org/en/stable/) 生成的键是扁平化元祖。（由于这种扁平化处理，如果调用的具体函数的嵌套结构与跟踪期间使用的不同，则会导致 TypeError）。\n",
        "- 对于所有其他 Python 类型，键基于对象 `id()`，以便为类的每个实例独立跟踪方法。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PEDwbumO32Wh"
      },
      "source": [
        "### 控制回溯\n",
        "\n",
        "回溯可以确保 TensorFlow 为每组输入生成正确的计算图。但是，跟踪操作非常消耗资源！如果 `Function` 为每一次调用都回溯新的计算图，您会发现代码的执行速度远不如不使用 `tf.function`。\n",
        "\n",
        "要控制跟踪行为，可以采用以下技巧："
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EUtycWJa34TT"
      },
      "source": [
        "- 在 `tf.function` 中指定 `input_signature` 来限制跟踪。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_BDMIRmu1RGB"
      },
      "outputs": [],
      "source": [
        "@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))\n",
        "def next_collatz(x):\n",
        "  print(\"Tracing with\", x)\n",
        "  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)\n",
        "\n",
        "print(next_collatz(tf.constant([1, 2])))\n",
        "# We specified a 1-D tensor in the input signature, so this should fail.\n",
        "with assert_raises(ValueError):\n",
        "  next_collatz(tf.constant([[1, 2], [3, 4]]))\n",
        "\n",
        "# We specified an int32 dtype in the input signature, so this should fail.\n",
        "with assert_raises(ValueError):\n",
        "  next_collatz(tf.constant([1.0, 2.0]))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ocxX-HVk7P2o"
      },
      "source": [
        "- 在 `tf.TensorSpec` 中指定 [None] 维度可灵活运用跟踪重用。\n",
        "\n",
        "    由于 TensorFlow 根据其形状匹配张量，因此，对于可变大小输入，使用 `None` 维度作为通配符可以让 `Function` 重复使用跟踪。对于每个批次，如果有不同长度的序列或不同大小的计算图，则会出现可变大小输入（请参阅 [Transformer](https://render.githubusercontent.com/tutorials/text/transformer.ipynb) 和 [Deep Dream](https://render.githubusercontent.com/tutorials/generative/deepdream.ipynb) 教程了解示例）。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4Viun7dh7PmF"
      },
      "outputs": [],
      "source": [
        "@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))\n",
        "def g(x):\n",
        "  print('Tracing with', x)\n",
        "  return x\n",
        "\n",
        "# No retrace!\n",
        "print(g(tf.constant([1, 2, 3])))\n",
        "print(g(tf.constant([1, 2, 3, 4, 5])))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AY5oiQN0XIyA"
      },
      "source": [
        "- 将 Python 参数转换为张量以减少回溯。\n",
        "\n",
        "    通常，Python 参数用于控制超参数和计算图构造，例如 `num_layers=10`、`training=True` 或 `nonlinearity='relu'`。所以，如果 Python 参数改变，则有必要回溯计算图。\n",
        "\n",
        "    但是，Python 参数有可能并未用于控制计算图构造。在这些情况下，Python 值的改变可能触发非必要的回溯。例如，在此训练循环中，AutoGraph 会动态展开。尽管有多个跟踪，但生成的计算图实际上是相同的，所以没有必要进行回溯。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uydzR5JYUU8H"
      },
      "outputs": [],
      "source": [
        "def train_one_step():\n",
        "  pass\n",
        "\n",
        "@tf.function\n",
        "def train(num_steps):\n",
        "  print(\"Tracing with num_steps = \", num_steps)\n",
        "  tf.print(\"Executing with num_steps = \", num_steps)\n",
        "  for _ in tf.range(num_steps):\n",
        "    train_one_step()\n",
        "\n",
        "print(\"Retracing occurs for different Python arguments.\")\n",
        "train(num_steps=10)\n",
        "train(num_steps=20)\n",
        "\n",
        "print()\n",
        "print(\"Traces are reused for Tensor arguments.\")\n",
        "train(num_steps=tf.constant(10))\n",
        "train(num_steps=tf.constant(20))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4pJqkDR_Q2wz"
      },
      "source": [
        "如果需要强制执行回溯，可以创建一个新的 `Function`。单独的 `Function` 对象肯定不会共享跟踪记录。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uHp4ousu4DdN"
      },
      "outputs": [],
      "source": [
        "def f():\n",
        "  print('Tracing!')\n",
        "  tf.print('Executing')\n",
        "\n",
        "tf.function(f)()\n",
        "tf.function(f)()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EJqHGFSVLIKl"
      },
      "source": [
        "### Python 副作用\n",
        "\n",
        "Python 副作用（如打印、追加到列表、改变全局变量）仅在第一次使用一组输入调用 `Function` 时才会发生。随后重新执行跟踪的 `tf.Graph`，而不执行 Python 代码。\n",
        "\n",
        "一般经验法则是仅使用 Python 副作用来调试跟踪记录。另外，对于每一次调用，TensorFlow 运算（如 `tf.Variable.assign`、`tf.print` 和 `tf.summary`）是确保代码得到 TensorFlow 运行时跟踪并执行的最佳方法。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "w2sACuZ9TTRk"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def f(x):\n",
        "  print(\"Traced with\", x)\n",
        "  tf.print(\"Executed with\", x)\n",
        "\n",
        "f(1)\n",
        "f(1)\n",
        "f(2)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "msTmv-oyUNaf"
      },
      "source": [
        "很多 Python 功能（如生成器和迭代器）依赖 Python 运行时来跟踪状态。通常，虽然这些构造在 Eager 模式下可以正常工作，但由于跟踪行为，`tf.function` 中会发生许多意外情况：\n",
        "\n",
        "举一个例子，推进迭代器状态是 Python 的一个副作用，因此只在跟踪过程中发生。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FNPD4unZUedH"
      },
      "outputs": [],
      "source": [
        "external_var = tf.Variable(0)\n",
        "@tf.function\n",
        "def buggy_consume_next(iterator):\n",
        "  external_var.assign_add(next(iterator))\n",
        "  tf.print(\"Value of external_var:\", external_var)\n",
        "\n",
        "iterator = iter([0, 1, 2, 3])\n",
        "buggy_consume_next(iterator)\n",
        "# This reuses the first value from the iterator, rather than consuming the next value.\n",
        "buggy_consume_next(iterator)\n",
        "buggy_consume_next(iterator)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wcS3TAgCjTWR"
      },
      "source": [
        "某些迭代构造通过 AutoGraph 获得支持。有关概述，请参阅 [AutoGraph 转换](#autograph_transformations)部分。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e1I0dPiqTV8H"
      },
      "source": [
        "如果希望在每次调用 `Function` 时都执行 Python 代码，`tf.py_function` 可以作为退出舱口。`tf.py_function` 的缺点是不可移植，性能不高，并且在分布式（多 GPU、TPU）设置中效果不佳。另外，由于 `tf.py_function` 必须连接到计算图中，它会将所有输入/输出转换为张量。\n",
        "\n",
        "`tf.gather`、`tf.stack` 和 `tf.TensorArray` 之类的 API 可帮助您在原生 TensorFlow 中实现常见循环模式。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7aJD--9qTWmg"
      },
      "outputs": [],
      "source": [
        "external_list = []\n",
        "\n",
        "def side_effect(x):\n",
        "  print('Python side effect')\n",
        "  external_list.append(x)\n",
        "\n",
        "@tf.function\n",
        "def f(x):\n",
        "  tf.py_function(side_effect, inp=[x], Tout=[])\n",
        "\n",
        "f(1)\n",
        "f(1)\n",
        "f(1)\n",
        "# The list append happens all three times!\n",
        "assert len(external_list) == 3\n",
        "# The list contains tf.constant(1), not 1, because py_function casts everything to tensors.\n",
        "assert external_list[0].numpy() == 1\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lPr_6mK_AQWL"
      },
      "source": [
        "### 变量\n",
        "\n",
        "在函数中创建新的 `tf.Variable` 时可能遇到错误。该错误是为了防止重复调用发生行为背离：在 Eager 模式下，每次调用函数时都会创建一个新变量，但是在 `Function` 中则不一定，这是因为重复使用了跟踪记录。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Tx0Vvnb_9OB-"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def f(x):\n",
        "  v = tf.Variable(1.0)\n",
        "  v.assign_add(x)\n",
        "  return v\n",
        "\n",
        "with assert_raises(ValueError):\n",
        "  f(1.0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KYm6-5GCILXQ"
      },
      "source": [
        "您也可以在 `Function` 内部创建变量，不过只能在第一次执行该函数时创建这些变量。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HQrG5_kOiKl_"
      },
      "outputs": [],
      "source": [
        "class Count(tf.Module):\n",
        "  def __init__(self):\n",
        "    self.count = None\n",
        "  \n",
        "  @tf.function\n",
        "  def __call__(self):\n",
        "    if self.count is None:\n",
        "      self.count = tf.Variable(0)\n",
        "    return self.count.assign_add(1)\n",
        "\n",
        "c = Count()\n",
        "print(c())\n",
        "print(c())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZoPg5w1Pjqna"
      },
      "source": [
        "您可能遇到的另一个错误是变量被回收。与常规 Python 函数不同，具体函数只会保留对它们闭包时所在变量的[弱引用](https://docs.python.org/3/library/weakref.html)，因此，您必须保留对任何变量的引用。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uMiRPfETjpt-"
      },
      "outputs": [],
      "source": [
        "external_var = tf.Variable(3)\n",
        "@tf.function\n",
        "def f(x):\n",
        "  return x * external_var\n",
        "\n",
        "traced_f = f.get_concrete_function(4)\n",
        "print(\"Calling concrete function...\")\n",
        "print(traced_f(4))\n",
        "\n",
        "del external_var\n",
        "print()\n",
        "print(\"Calling concrete function after garbage collecting its closed Variable...\")\n",
        "with assert_raises(tf.errors.FailedPreconditionError):\n",
        "  traced_f(4)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5f05Vr_YBUCz"
      },
      "source": [
        "## AutoGraph 转换\n",
        "\n",
        "AutoGraph 是一个库，在 `tf.function` 中默认处于启用状态。它可以将 Python Eager 代码的子集转换为与计算图兼容的 TensorFlow 运算。这包括 `if`、`for`、`while` 等控制流。\n",
        "\n",
        "`tf.cond` 和 `tf.while_loop` 等 TensorFlow 运算仍然可以运行，但是使用 Python 编写时，控制流通常更易于编写，代码也更易于理解。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yCQTtTPTW3WF"
      },
      "outputs": [],
      "source": [
        "# Simple loop\n",
        "\n",
        "@tf.function\n",
        "def f(x):\n",
        "  while tf.reduce_sum(x) &gt; 1:\n",
        "    tf.print(x)\n",
        "    x = tf.tanh(x)\n",
        "  return x\n",
        "\n",
        "f(tf.random.uniform([5]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KxwJ8znPI0Cg"
      },
      "source": [
        "如果您有兴趣，可以检查 Autograph 生成的代码。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jlQD1ffRXJhl"
      },
      "outputs": [],
      "source": [
        "print(tf.autograph.to_code(f.python_function))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xgKmkrNTZSyz"
      },
      "source": [
        "### 条件语句\n",
        "\n",
        "AutoGraph 会将某些 `if <condition>` 语句转换为等效的 `tf.cond` 调用。如果 `<condition>` 是张量，则会执行这种替换，否则会将 `if` 语句作为 Python 条件语句执行。\n",
        "\n",
        "Python 条件语句在跟踪时执行，因此会将该条件语句的一个分支添加到计算图。如果不使用 AutoGraph，当存在依赖于数据的控制流时，此跟踪计算图将无法选择替代分支。\n",
        "\n",
        "`tf.cond` 跟踪并将条件的两个分支添加到计算图，在执行时动态选择分支。跟踪可能产生意外的副作用；有关详细信息，请参阅 [AutoGraph 跟踪作用](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#effects-of-the-tracing-process)。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BOQl8PMq2Sf3"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def fizzbuzz(n):\n",
        "  for i in tf.range(1, n + 1):\n",
        "    print('Tracing for loop')\n",
        "    if i % 15 == 0:\n",
        "      print('Tracing fizzbuzz branch')\n",
        "      tf.print('fizzbuzz')\n",
        "    elif i % 3 == 0:\n",
        "      print('Tracing fizz branch')\n",
        "      tf.print('fizz')\n",
        "    elif i % 5 == 0:\n",
        "      print('Tracing buzz branch')\n",
        "      tf.print('buzz')\n",
        "    else:\n",
        "      print('Tracing default branch')\n",
        "      tf.print(i)\n",
        "\n",
        "fizzbuzz(tf.constant(5))\n",
        "fizzbuzz(tf.constant(20))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4rBO5AQ15HVC"
      },
      "source": [
        "有关 AutoGraph 转换的 if 语句的其他限制，请参阅[参考文档](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#if-statements)。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yho4J0a0ZkQS"
      },
      "source": [
        "### 循环\n",
        "\n",
        "AutoGraph 会将某些 `for` 和 `while` 语句转换为等效的 TensorFlow 循环运算，例如 `tf.while_loop`。如果不转换，则会将 `for` 或 `while` 循环作为 Python 循环执行。\n",
        "\n",
        "以下情形会执行这种替换：\n",
        "\n",
        "- `for x in y`：如果 `y` 是一个张量，则转换为 `tf.while_loop`。在特殊情况下，如果 `y` 是 `tf.data.Dataset`，则会生成 `tf.data.Dataset` 运算的组合。\n",
        "- `while <condition>`：如果 `<condition>` 是张量，则转换为 `tf.while_loop`。\n",
        "\n",
        "Python 循环在跟踪时执行，因而循环每迭代一次，都会将额外的运算添加到 `tf.Graph`。\n",
        "\n",
        "TensorFlow 循环会跟踪循环体，并在执行时动态选择迭代的运行次数。循环体仅在生成的 `tf.Graph` 中出现一次。\n",
        "\n",
        "有关 AutoGraph 转换的 `for` 和 `while` 语句的其他限制，请参阅[参考文档](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#while-statements)。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sp4rbIdfbM6s"
      },
      "source": [
        "#### 在 Python 数据上循环\n",
        "\n",
        "一个常见陷阱是在 `tf.function` 中的 Python/Numpy 数据上循环。此循环在跟踪过程中执行，因而循环每迭代一次，都会将模型的一个副本添加到 `tf.Graph`。\n",
        "\n",
        "如果要在 `tf.function` 中包装整个训练循环，最安全的方法是将数据包装为 `tf.data.Dataset`，以便 AutoGraph 动态展开训练循环。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WGZ19LspbZ27"
      },
      "outputs": [],
      "source": [
        "def measure_graph_size(f, *args):\n",
        "  g = f.get_concrete_function(*args).graph\n",
        "  print(\"{}({}) contains {} nodes in its graph\".format(\n",
        "      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))\n",
        "\n",
        "@tf.function\n",
        "def train(dataset):\n",
        "  loss = tf.constant(0)\n",
        "  for x, y in dataset:\n",
        "    loss += tf.abs(y - x) # Some dummy computation.\n",
        "  return loss\n",
        "\n",
        "small_data = [(1, 1)] * 3\n",
        "big_data = [(1, 1)] * 10\n",
        "measure_graph_size(train, small_data)\n",
        "measure_graph_size(train, big_data)\n",
        "\n",
        "measure_graph_size(train, tf.data.Dataset.from_generator(\n",
        "    lambda: small_data, (tf.int32, tf.int32)))\n",
        "measure_graph_size(train, tf.data.Dataset.from_generator(\n",
        "    lambda: big_data, (tf.int32, tf.int32)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JeD2U-yrbfVb"
      },
      "source": [
        "在数据集中包装 Python/Numpy 数据时，要注意 `tf.data.Dataset.from_generator` 与 ` tf.data.Dataset.from_tensors`。前者将数据保留在 Python 中，并通过 `tf.py_function` 获取，这可能会影响性能；后者将数据的副本捆绑成计算图中的一个大 `tf.constant()` 节点，这可能会消耗较多内存。\n",
        "\n",
        "通过 TFRecordDataset/CsvDataset 等从文件中读取数据是最高效的数据使用方式，因为这样 TensorFlow 就可以自行管理数据的异步加载和预提取，不必利用 Python。要了解详细信息，请参阅 [tf.data 指南](https://render.githubusercontent.com/guide/data)。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hyksHW9TCukR"
      },
      "source": [
        "#### 累加循环值\n",
        "\n",
        "一种常见模式是不断累加循环的中间值。通常，这可以通过将元素追加到 Python 列表或将条目添加到 Python 字典来实现。但是，由于存在 Python 副作用，在动态展开循环中，这些方法无法达到预期效果。要从动态展开循环累加结果，可以使用 `tf.TensorArray` 来实现。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HJ3Vb3dXfefN"
      },
      "outputs": [],
      "source": [
        "batch_size = 2\n",
        "seq_len = 3\n",
        "feature_size = 4\n",
        "\n",
        "def rnn_step(inp, state):\n",
        "  return inp + state\n",
        "\n",
        "@tf.function\n",
        "def dynamic_rnn(rnn_step, input_data, initial_state):\n",
        "  # [batch, time, features] -&gt; [time, batch, features]\n",
        "  input_data = tf.transpose(input_data, [1, 0, 2])\n",
        "  max_seq_len = input_data.shape[0]\n",
        "\n",
        "  states = tf.TensorArray(tf.float32, size=max_seq_len)\n",
        "  state = initial_state\n",
        "  for i in tf.range(max_seq_len):\n",
        "    state = rnn_step(input_data[i], state)\n",
        "    states = states.write(i, state)\n",
        "  return tf.transpose(states.stack(), [1, 0, 2])\n",
        "  \n",
        "dynamic_rnn(rnn_step,\n",
        "            tf.random.uniform([batch_size, seq_len, feature_size]),\n",
        "            tf.zeros([batch_size, feature_size]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IKyrEY5GVX3M"
      },
      "source": [
        "## 延伸阅读\n",
        "\n",
        "要了解如何导出和加载 `Function`，请参阅 [SavedModel 指南](https://render.githubusercontent.com/guide/saved_model)。要详细了解跟踪后执行的计算图优化，请参阅 [Grappler 指南](https://render.githubusercontent.com/guide/graph_optimization)。要了解如何优化数据流水线和分析模型，请参阅 [Profiler 指南](https://render.githubusercontent.com/guide/profiler.md)。"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "function.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
